package net.xiake6.orm.datasource.sharding;

import java.lang.reflect.Field;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import com.alibaba.druid.filter.FilterChain;
import com.alibaba.druid.filter.FilterEventAdapter;
import com.alibaba.druid.proxy.jdbc.ConnectionProxy;
import com.alibaba.druid.proxy.jdbc.PreparedStatementProxy;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLBooleanExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLNullExpr;
import com.alibaba.druid.sql.ast.expr.SQLNumericLiteralExpr;
import com.alibaba.druid.sql.ast.expr.SQLTextLiteralExpr;
import com.alibaba.druid.sql.ast.expr.SQLValuableExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.mysql.jdbc.JDBC4CallableStatement;
import com.mysql.jdbc.JDBC4PreparedStatement;
import com.mysql.jdbc.JDBC4ServerPreparedStatement;
import com.mysql.jdbc.StringUtils;

import lombok.extern.slf4j.Slf4j;
import net.xiake6.orm.datasource.HandleDataSource;

/**
 * Druid filter that rewrites the table names in a SQL statement to their sharding tables.
 * ClassName DynamicTableFilter.java
 *
 * <p>Before a prepared statement executes, the filter renders the SQL with its parameter values.
 * It derives a sharding id from the WHERE condition (SELECT/UPDATE/DELETE) or from the inserted
 * column values (INSERT). It then uses reflection to rewrite the SQL held inside the MySQL
 * Connector/J statement, so the driver runs the statement against the sharding table.
 *
 * @author fenglibin
 * @Blog http://xiake6.net
 * @Date 2019-12-11
 */
@Slf4j
@Service
public class DynamicTableFilter extends FilterEventAdapter {
	@Autowired
	private ShardingConfig shardingConfig;

	/**
	 * Logs the SQL as written by the application, then delegates to the default behaviour.
	 */
	@Override
	public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
			throws SQLException {
		log.info("原SQL:{}", sql);
		return super.connection_prepareStatement(chain, connection, sql);
	}

	/**
	 * Rewrites the statement's table names to the sharding tables, then executes it.
	 *
	 * <p>Rewriting is best effort: any failure is logged and the statement is executed unchanged.
	 */
	@Override
	public boolean preparedStatement_execute(FilterChain chain, PreparedStatementProxy statement) throws SQLException {
		try {
			rewriteToShardingTable(statement.getRawObject());
		} catch (Exception e) {
			// Deliberately swallowed: fall back to executing the original SQL.
			log.error(e.getMessage(), e);
		}
		return super.preparedStatement_execute(chain, statement);
	}

	/**
	 * Replaces the table names inside the driver's prepared statement with the sharding tables.
	 *
	 * @param ps the raw (driver-level) prepared statement
	 * @throws Exception on reflection failures or any error raised while parsing/rewriting the SQL
	 */
	private void rewriteToShardingTable(PreparedStatement ps) throws Exception {
		String sql = getSql(ps);
		if (sql == null) {
			log.warn("无法从PreparedStatement中获取SQL，跳过分表处理：{}", ps);
			return;
		}
		log.info("补充了参数值的原始SQL：{}", sql);

		int tableShadingId = getTableShadingId(sql);
		sql = TableParser.getShardingTableSql(sql, shardingConfig.getShardingTables(), tableShadingId);

		// Two driver fields have to be modified.
		// 1) originalSql: per the original author's notes it is mainly for display and is not what
		//    the driver actually executes. It is set to the rewritten, parameter-filled SQL.
		Field originalSqlField = findField(ps.getClass(), "originalSql");
		log.debug("未补充参数值的原始SQL：{}", originalSqlField.get(ps));
		originalSqlField.set(ps, sql);
		log.debug("未补充参数值、修改表名为分表后的原始SQL：{}", originalSqlField.get(ps));

		// 2) parseInfo.staticSql: per the original author's notes, this holds the un-parameterised SQL
		//    (as byte chunks) that the driver fills with parameter values before execution.
		// NOTE(review): the chunks are rewritten in place. If the driver shares/caches ParseInfo
		// (e.g. cachePrepStmts) or the same statement object is re-executed, a later execution may
		// not be re-routed to a different shard — verify against the driver configuration.
		Field parseInfoField = findField(ps.getClass(), "parseInfo");
		Object parseInfo = parseInfoField.get(ps);
		Field staticSqlField = findField(parseInfo.getClass(), "staticSql");
		byte[][] staticSqlChunks = (byte[][]) staticSqlField.get(parseInfo);
		for (int i = 0; i < staticSqlChunks.length; i++) {
			String chunk = StringUtils.toString(staticSqlChunks[i]);
			chunk = TableParser.getShardingTableSql(chunk, shardingConfig.getShardingTables(), tableShadingId);
			staticSqlChunks[i] = StringUtils.getBytes(chunk);
		}
		log.info("指行SQL的数据源为：{}，最终会被执行的SQL：{}", HandleDataSource.getDataSource(), getSql(ps));
	}

	/**
	 * Finds a field declared on {@code type} or on any of its superclasses and makes it accessible.
	 *
	 * <p>{@link Class#getDeclaredField} only inspects a single class. The concrete driver class may sit
	 * more than one level below the class that declares the field, so the whole hierarchy is walked.
	 *
	 * @param type the class to start searching from
	 * @param name the field name
	 * @return the accessible field
	 * @throws NoSuchFieldException if no class in the hierarchy declares the field
	 */
	private static Field findField(Class<?> type, String name) throws NoSuchFieldException {
		for (Class<?> current = type; current != null; current = current.getSuperclass()) {
			try {
				Field field = current.getDeclaredField(name);
				field.setAccessible(true);
				return field;
			} catch (NoSuchFieldException ignored) {
				// Not declared on this class; keep walking up the hierarchy.
			}
		}
		throw new NoSuchFieldException(name + " is not declared in the class hierarchy of " + type.getName());
	}

	/**
	 * Returns the statement's SQL with its parameter values filled in.
	 *
	 * @param ps the raw prepared statement
	 * @return the rendered SQL, or {@code null} if {@code ps} is not one of the supported MySQL
	 *         Connector/J statement types
	 * @throws SQLException if the driver fails to render the SQL
	 */
	private String getSql(PreparedStatement ps) throws SQLException {
		if (ps instanceof JDBC4PreparedStatement) {
			return ((JDBC4PreparedStatement) ps).asSql();
		}
		if (ps instanceof JDBC4CallableStatement) {
			return ((JDBC4CallableStatement) ps).asSql();
		}
		if (ps instanceof JDBC4ServerPreparedStatement) {
			return ((JDBC4ServerPreparedStatement) ps).asSql();
		}
		return null;
	}

	/**
	 * Computes the sharding id for a fully rendered SQL statement.
	 *
	 * <p>SELECT, UPDATE and DELETE are routed by their WHERE condition, and INSERT by its column values.
	 * Any other statement (or one without a usable condition) is routed with an empty map.
	 *
	 * @param sql SQL with parameter values already filled in
	 * @return the sharding id chosen by the configured table sharding condition
	 */
	private int getTableShadingId(String sql) {
		MySqlStatementParser parser = new MySqlStatementParser(sql);
		SQLStatement stmt = parser.parseStatement();
		if (log.isDebugEnabled()) {
			stmt.getAttributes().forEach((k, v) -> log.debug("key is:" + k + ",value is:" + v));
		}

		Object condition = null;
		if (stmt instanceof SQLSelectStatement) {
			SQLSelect select = ((SQLSelectStatement) stmt).getSelect();
			// getQueryBlock() may be null (e.g. for a UNION), which has no single WHERE clause.
			if (select.getQueryBlock() != null) {
				condition = select.getQueryBlock().getWhere();
			}
		} else if (stmt instanceof MySqlUpdateStatement) {
			condition = ((MySqlUpdateStatement) stmt).getWhere();
		} else if (stmt instanceof MySqlDeleteStatement) {
			condition = ((MySqlDeleteStatement) stmt).getWhere();
		} else if (stmt instanceof MySqlInsertStatement) {
			condition = stmt;
		}

		return getTableShadingId(condition);
	}

	/**
	 * Computes the sharding id from a WHERE expression or an INSERT statement.
	 *
	 * @param obj a WHERE {@link SQLExpr}, a {@link MySqlInsertStatement}, or {@code null}
	 * @return the sharding id chosen by the configured table sharding condition
	 */
	private int getTableShadingId(Object obj) {
		Map<String, Object> fieldNameAndValue = getKeyValue(obj);
		return shardingConfig.getTableShardingCondition().getShardingId(fieldNameAndValue);
	}

	/**
	 * Builds a column-name to value map from a WHERE expression or from an INSERT statement.
	 *
	 * @param obj a WHERE {@link SQLExpr}, a {@link MySqlInsertStatement}, or anything else
	 * @return the extracted map, never {@code null} (empty when nothing can be extracted)
	 */
	private Map<String, Object> getKeyValue(Object obj) {
		if (obj instanceof SQLExpr) {
			return getWhereKeyValue((SQLExpr) obj);
		}
		if (obj instanceof MySqlInsertStatement) {
			return getInsertKeyValue((MySqlInsertStatement) obj);
		}
		return new HashMap<String, Object>();
	}

	/**
	 * Collects every {@code column = literal} condition of a WHERE expression into a map.
	 *
	 * @param sqlExpr the WHERE expression
	 * @return column name to literal value (see {@link #literalValue(SQLExpr)})
	 */
	private Map<String, Object> getWhereKeyValue(SQLExpr sqlExpr) {
		Map<String, Object> map = new HashMap<String, Object>();
		getWhereKeyValue(map, sqlExpr);
		return map;
	}

	/**
	 * Recursively walks nested binary expressions (AND, OR, ...). Each equality between a plain
	 * column and a literal (in either order) is recorded in {@code map}. Later matches for the same
	 * column overwrite earlier ones.
	 */
	private void getWhereKeyValue(final Map<String, Object> map, SQLExpr sqlExpr) {
		if (!(sqlExpr instanceof SQLBinaryOpExpr)) {
			return;
		}
		SQLBinaryOpExpr binary = (SQLBinaryOpExpr) sqlExpr;
		SQLExpr left = binary.getLeft();
		SQLExpr right = binary.getRight();
		if (binary.getOperator() == SQLBinaryOperator.Equality) {
			putColumnLiteral(map, left, right);
			putColumnLiteral(map, right, left);
			return;
		}
		getWhereKeyValue(map, left);
		getWhereKeyValue(map, right);
	}

	/** Records {@code column -> value} when {@code column} is a plain identifier and {@code value} a literal. */
	private static void putColumnLiteral(Map<String, Object> map, SQLExpr column, SQLExpr value) {
		if (column instanceof SQLIdentifierExpr && isLiteral(value)) {
			map.put(((SQLIdentifierExpr) column).normalizedName(), literalValue(value));
		}
	}

	/**
	 * Builds a column-name to value map from the first VALUES row of an INSERT statement.
	 *
	 * <p>Columns and values are paired by position. Pairs whose column is not a plain identifier, or
	 * whose value is not a literal (e.g. a function call), are skipped instead of shifting the others.
	 *
	 * @param insert the INSERT statement
	 * @return column name to literal value; empty when the statement has no VALUES clause
	 * @throws RuntimeException if the number of columns differs from the number of values
	 */
	private Map<String, Object> getInsertKeyValue(MySqlInsertStatement insert) {
		Map<String, Object> map = new HashMap<String, Object>();
		if (insert.getValues() == null) {
			// e.g. INSERT ... SELECT: there are no literal values to route by.
			return map;
		}
		List<SQLExpr> columns = insert.getColumns();
		List<SQLExpr> values = insert.getValues().getValues();
		if (columns.size() != values.size()) {
			throw new RuntimeException(
					"Insert语句中的字段名称的数量" + columns.size() + " 与值的的数量:" + values.size() + " 不相等！");
		}
		for (int i = 0; i < columns.size(); i++) {
			putColumnLiteral(map, columns.get(i), values.get(i));
		}
		return map;
	}

	/** Returns true for the literal kinds that {@link #literalValue(SQLExpr)} can convert. */
	private static boolean isLiteral(SQLExpr expr) {
		return expr instanceof SQLTextLiteralExpr || expr instanceof SQLNumericLiteralExpr
				|| expr instanceof SQLNullExpr || expr instanceof SQLBooleanExpr;
	}

	/**
	 * Converts a literal to the value stored in the routing map. Text becomes its text, numbers their
	 * string form, NULL becomes {@code null}, and booleans their Boolean value.
	 */
	private static Object literalValue(SQLExpr expr) {
		if (expr instanceof SQLTextLiteralExpr) {
			return ((SQLTextLiteralExpr) expr).getText();
		}
		if (expr instanceof SQLNumericLiteralExpr) {
			return ((SQLNumericLiteralExpr) expr).getNumber().toString();
		}
		if (expr instanceof SQLBooleanExpr) {
			return ((SQLBooleanExpr) expr).getValue();
		}
		return null;
	}

}
